{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6196c6f2-f088-4c28-b9a8-e921f9a7465d",
   "metadata": {},
   "source": [
    "# Custom Tokenization for DAPT (Domain Adaptive Pre-Training)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbd33f30-2a18-480f-a7f1-210ac99b937c",
   "metadata": {},
   "source": [
    "This notebook walks through the custom tokenization workflow required for DAPT (Domain Adaptive Pre-training) as shown in the schematic diagram below. \n",
    "\n",
    "![pipeline](imgs/tokenization_diagram.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56f509da",
   "metadata": {},
   "source": [
    "### Custom Tokenization Workflow"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12fe579a",
   "metadata": {},
   "source": [
    "#### Goal\n",
    "Given a pre-trained tokenizer trained on general purpose datasets (<b>Original Tokenizer</b>), our goal is to adapt it to a given domain that we want to apply it to (in this notebook, the example domain we are looking at is ChipDesign).\n",
    "\n",
    "When adapting a pre-trained tokenizer to a given domain, the main goals are to improve tokenization efficiency on domain-specific data, maintain efficiency and language model performance on general purpose datasets, and minimize the effort for retraining/fine-tuning. Since we don't have access to the entire general purpose data used for pretraining the original tokenizer, we want to preserve the existing token mappings, and any new tokens that are added should be strictly an \"extension\". \n",
    "\n",
    "Generally, when adapting tokenizer to domain-specific data, the goal is to create a tokenizer that is better suited to the vocabulary and structure of that specific domain. This can improve the efficiency and performance of the model on tasks within that domain through efficient representation of domain specific information.\n",
    "\n",
    "#### Approach \n",
    "The general approach we adopt is to train a <b>Domain Specific Tokenizer</b> from scratch on domain data and use it to identify domain specific tokens that are missing from the original tokenizer. This is done by simply comparing the vocabs of the Original Tokenizer and the newly trained Domain Specific Tokenizer. The missing domain specific tokens are then added to the original tokenizer for extending it to get the final <b>Domain Adapted Tokenizer</b>. \n",
    "\n",
    "#### Tradeoff \n",
    "However, there is a tradeoff to adding missing domain specific tokens to the Original Tokenizer. The challenge is to balance this tradeoff between tokenization efficiency on domain data vs disturbance to the performance on general-purpose data as a result of adding domain specific tokens to the Original Tokenizer.\n",
    "\n",
    "For instance, addition of a large no. of domain specific tokens can lead to higher efficiency on domain specific data, but DAPT process might take longer since it would take longer for the loss to converge​ due to disturbance of efficiency/performance on the general purpose data.\n",
    "\n",
    "On the other hand, addition of only a small no. of domain specific tokens can lead to maintained efficiency on general purpose data, but may lack coverage on the domain specific dataset​.\n",
    "\n",
    "#### Balancing The Tradeoff\n",
    "To balance this tradeoff, instead of adding all identified missing domain specific tokens to the original tokenizer, we identify the most frequently occuring tokens using a threshold and only add the ones with usage frequencies above the given threshold to get the final Domain Adapted Tokenizer. \n",
    "\n",
    "For identifying the most frequently used tokens, we first extend the Original Tokenizer by adding all identified missing domain specific tokens to get an <b>Extended Tokenizer</b>. The Extended Tokenizer is then applyied to the domain specific data in order to identify high frequency tokens. Thus the Extended Tokenizer is just an intermediate step towards building a Domain Adapted Tokenizer.\n",
    "\n",
    "Finally, the Original Tokenizer is extended using only high frequency tokens to get the final <b>Domain Adapted Tokenizer</b>. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0b69b3b-aa66-42c9-b76f-2fde7f29a4b0",
   "metadata": {},
   "source": [
    "## Notebook Outline\n",
    "\n",
    "To achieve the process described above, we’ve developed a step-by-step approach that this notebook will walk you through:\n",
    "\n",
    "- Step 0: Install pre-requisites and import the required modules\n",
    "- Step 1: Download llama-2-70b embedding model and tokenizer (<b>Original Tokenizer</b>). Convert the orginal weights to trainable format and save. \n",
    "- Step 2: Train an opt-350m tokenizer from scratch using domain-specific data to get a <b>Domain Specific Tokenizer</b>.\n",
    "- Step 3: From the vocabulary of the newly trained tokenizer, identifying tokens that are absent in the general-purpose tokenizer and are rarely found in general-purpose datasets. Next, expand the general-purpose tokenizer with the newly identified tokens to get an <b>Extended Tokenizer</b>.\n",
    "- Step 4: Apply the Extended Tokenizer to the domain-specific dataset, analyze the usage frequencies of the newly-added tokens, and select the top-K tokens in a way that their cumulative frequency accounts for approximately 98% (a hyper-parameter) of the total frequency of the new tokens.\n",
    "- Step 5: Initialize the embeddings of the new tokens by utilizing the general-purpose tokenizer i.e., Original Tokenizer. When a new token is encountered, it is tokenized using the pretrained general-purpose tokenizer. The embedding and output layer weights corresponding to the new token are determined by averaging the embeddings / weights corresponding to the tokens generated by the general-purpose tokenizer.\n",
    "- Step 6: Merge the new embeddings with the original embedding table (in llama2-2-70b) to get the final <b>Domain Adapted Tokenizer</b>.\n",
    "## Data\n",
    "\n",
    "In this playbook, we will leverage chip domain/hardware datasets from open-source GitHub repositories, wiki URLs, and academic papers. Data has been processed and curated using [NeMo Curator](https://github.com/NVIDIA-NeMo/Curator/tree/dask) as shown in this [playbook](https://github.com/jvamaraju/ndc_dapt_playbook/tree/dapt_jv). Please note that this tutorial uses NeMo Curator version 0.9.0 or lower."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbee82dd",
   "metadata": {},
   "source": [
    "## NeMo Tools and Resources\n",
    "\n",
    "* [Nvidia Nemo Framework](https://github.com/NVIDIA/NeMo)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74be8ece",
   "metadata": {},
   "source": [
    "## Software Requirements\n",
    "* Access to latest NeMo Framework NGC Containers\n",
    "* This playbook has been tested on: nvcr.io/nvidia/nemo:24.07. It is expected to work similarly on other environments. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2b5ad09",
   "metadata": {},
   "source": [
    "## Hardware Requirements\n",
    "* This playbook can run on CPUs or GPUs. For GPUs, this playbook has been tested on minimum 1xA100 80G"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80bae538-308f-4d1b-8186-69de6226f3cd",
   "metadata": {},
   "source": [
    "## Step 0: install the prerequisites and import the required modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc83794a-daf9-44fb-9b89-f8cde05101a4",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "! pip install datasets sentencepiece jsonlines tokenizers transformers torch ftfy matplotlib\n",
    "! pip install protobuf==3.20.1\n",
    "! pip install --upgrade jupyter ipywidgets widgetsnbextension pandas-profiling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a91cd358",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import torch\n",
    "from datasets import Dataset\n",
    "from datasets import IterableDataset\n",
    "from datasets import load_dataset\n",
    "import jsonlines\n",
    "import glob\n",
    "from tokenizers import (\n",
    "    decoders,\n",
    "    models,\n",
    "    normalizers,\n",
    "    pre_tokenizers,\n",
    "    processors,\n",
    "    trainers,\n",
    "    Tokenizer,\n",
    ")\n",
    "from transformers import AutoTokenizer\n",
    "from tokenization_helper import *\n",
    "from extend_tokenizer_utils import extend_tokenizer, extend_tokenizer_high_freq_tokens\n",
    "from get_high_freq_tokens import get_high_freq_tokens\n",
    "from util import load_weights, merge_embed"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24d00d02",
   "metadata": {},
   "source": [
    "## Step 1: Download llama-2-70b embedding model and tokenizer (Original Tokenizer). Convert the orginal weights to trainable format and save. \n",
    "\n",
    "The Original Tokenizer model used here is the llama2 tokenizer which is a Byte Pair Encoding (BPE) model based on sentencepiece.\n",
    "\n",
    "Here we first log into the Hugging Face before downloading the model since the model is in a restricted repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38cf0264",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Install the hugging face CLI\n",
    "! pip install -U \"huggingface_hub[cli]\"\n",
    "# Generate a user access token at https://huggingface.co/settings/tokens\n",
    "\n",
    "# To download the model, please login via huggingface-cli login since it is a restricted repo\n",
    "! huggingface-cli login\n",
    "# You will be prompted to enter your User Access Token. Copy and paste the token, then press Enter. The CLI will verify the token and save it locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d54f21e5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# create directory for storing the downloaded hugging face model \n",
    "os.makedirs(\"models/weight/llama2-hf\", exist_ok=True)\n",
    "\n",
    "# create directories for storing the model weights \n",
    "os.makedirs(\"models/weight/llama2/ori_llama2-hf_weight\", exist_ok=True)\n",
    "os.makedirs(\"models/weight/llama2/new_llama2-hf_weight\", exist_ok=True)\n",
    "\n",
    "# create directories for storing the tokenizers\n",
    "os.makedirs(\"models/tokenizer/llama2/original_tokenizer\", exist_ok=True)\n",
    "os.makedirs(\"models/tokenizer/llama2/new_tokenizer\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "558753ac",
   "metadata": {},
   "source": [
    "Before running the next step, make sure you have access granted for Meta's Llama2 models gated group. You can fill the form available on https://huggingface.co/meta-llama/Llama-2-7b in order to get the access. (Takes ~20 minutes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7f0c3988",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# download llama2-70b model weights and tokenizer \n",
    "! huggingface-cli download meta-llama/Llama-2-70b --local-dir ./models/weight/llama2-hf/\n",
    "\n",
    "# #Copy original tokenizer to a different folder\n",
    "! cp ./models/weight/llama2-hf/tokenizer.model ./models/tokenizer/llama2/original_tokenizer\n",
    "\n",
    "# Load embedding and output layer  weights (size = (vocab size,embedding dim)) from each snapshot and create a dict\n",
    "load_path = \"./models/weight/llama2-hf\"\n",
    "save_path = './models/weight/llama2/ori_llama2-hf_weight'\n",
    "\n",
    "if not os.path.exists(save_path):\n",
    "    os.makedirs(save_path)\n",
    "    \n",
    "#load weight and store in a dictionary suitable for NeMo\n",
    "load_weights(load_path, save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6ade9a02-d38f-436c-82d0-bd16b54dbbf8",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index: 0, layer: tok_embeddings.weight, Layer size: torch.Size([32000, 1024])\n",
      "Index: 1, layer: norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 2, layer: output.weight, Layer size: torch.Size([4000, 8192])\n",
      "Index: 3, layer: layers.0.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 4, layer: layers.0.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 5, layer: layers.0.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 6, layer: layers.0.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 7, layer: layers.0.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 8, layer: layers.0.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 9, layer: layers.0.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 10, layer: layers.0.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 11, layer: layers.0.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 12, layer: layers.1.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 13, layer: layers.1.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 14, layer: layers.1.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 15, layer: layers.1.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 16, layer: layers.1.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 17, layer: layers.1.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 18, layer: layers.1.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 19, layer: layers.1.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 20, layer: layers.1.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 21, layer: layers.2.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 22, layer: layers.2.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 23, layer: layers.2.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 24, layer: layers.2.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 25, layer: layers.2.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 26, layer: layers.2.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 27, layer: layers.2.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 28, layer: layers.2.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 29, layer: layers.2.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 30, layer: layers.3.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 31, layer: layers.3.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 32, layer: layers.3.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 33, layer: layers.3.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 34, layer: layers.3.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 35, layer: layers.3.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 36, layer: layers.3.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 37, layer: layers.3.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 38, layer: layers.3.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 39, layer: layers.4.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 40, layer: layers.4.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 41, layer: layers.4.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 42, layer: layers.4.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 43, layer: layers.4.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 44, layer: layers.4.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 45, layer: layers.4.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 46, layer: layers.4.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 47, layer: layers.4.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 48, layer: layers.5.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 49, layer: layers.5.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 50, layer: layers.5.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 51, layer: layers.5.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 52, layer: layers.5.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 53, layer: layers.5.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 54, layer: layers.5.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 55, layer: layers.5.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 56, layer: layers.5.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 57, layer: layers.6.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 58, layer: layers.6.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 59, layer: layers.6.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 60, layer: layers.6.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 61, layer: layers.6.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 62, layer: layers.6.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 63, layer: layers.6.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 64, layer: layers.6.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 65, layer: layers.6.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 66, layer: layers.7.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 67, layer: layers.7.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 68, layer: layers.7.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 69, layer: layers.7.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 70, layer: layers.7.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 71, layer: layers.7.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 72, layer: layers.7.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 73, layer: layers.7.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 74, layer: layers.7.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 75, layer: layers.8.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 76, layer: layers.8.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 77, layer: layers.8.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 78, layer: layers.8.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 79, layer: layers.8.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 80, layer: layers.8.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 81, layer: layers.8.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 82, layer: layers.8.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 83, layer: layers.8.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 84, layer: layers.9.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 85, layer: layers.9.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 86, layer: layers.9.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 87, layer: layers.9.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 88, layer: layers.9.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 89, layer: layers.9.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 90, layer: layers.9.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 91, layer: layers.9.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 92, layer: layers.9.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 93, layer: layers.10.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 94, layer: layers.10.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 95, layer: layers.10.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 96, layer: layers.10.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 97, layer: layers.10.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 98, layer: layers.10.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 99, layer: layers.10.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 100, layer: layers.10.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 101, layer: layers.10.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 102, layer: layers.11.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 103, layer: layers.11.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 104, layer: layers.11.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 105, layer: layers.11.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 106, layer: layers.11.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 107, layer: layers.11.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 108, layer: layers.11.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 109, layer: layers.11.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 110, layer: layers.11.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 111, layer: layers.12.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 112, layer: layers.12.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 113, layer: layers.12.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 114, layer: layers.12.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 115, layer: layers.12.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 116, layer: layers.12.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 117, layer: layers.12.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 118, layer: layers.12.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 119, layer: layers.12.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 120, layer: layers.13.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 121, layer: layers.13.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 122, layer: layers.13.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 123, layer: layers.13.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 124, layer: layers.13.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 125, layer: layers.13.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 126, layer: layers.13.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 127, layer: layers.13.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 128, layer: layers.13.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 129, layer: layers.14.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 130, layer: layers.14.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 131, layer: layers.14.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 132, layer: layers.14.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 133, layer: layers.14.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 134, layer: layers.14.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 135, layer: layers.14.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 136, layer: layers.14.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 137, layer: layers.14.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 138, layer: layers.15.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 139, layer: layers.15.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 140, layer: layers.15.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 141, layer: layers.15.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 142, layer: layers.15.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 143, layer: layers.15.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 144, layer: layers.15.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 145, layer: layers.15.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 146, layer: layers.15.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 147, layer: layers.16.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 148, layer: layers.16.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 149, layer: layers.16.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 150, layer: layers.16.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 151, layer: layers.16.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 152, layer: layers.16.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 153, layer: layers.16.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 154, layer: layers.16.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 155, layer: layers.16.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 156, layer: layers.17.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 157, layer: layers.17.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 158, layer: layers.17.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 159, layer: layers.17.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 160, layer: layers.17.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 161, layer: layers.17.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 162, layer: layers.17.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 163, layer: layers.17.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 164, layer: layers.17.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 165, layer: layers.18.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 166, layer: layers.18.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 167, layer: layers.18.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 168, layer: layers.18.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 169, layer: layers.18.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 170, layer: layers.18.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 171, layer: layers.18.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 172, layer: layers.18.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 173, layer: layers.18.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 174, layer: layers.19.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 175, layer: layers.19.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 176, layer: layers.19.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 177, layer: layers.19.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 178, layer: layers.19.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 179, layer: layers.19.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 180, layer: layers.19.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 181, layer: layers.19.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 182, layer: layers.19.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 183, layer: layers.20.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 184, layer: layers.20.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 185, layer: layers.20.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 186, layer: layers.20.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 187, layer: layers.20.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 188, layer: layers.20.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 189, layer: layers.20.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 190, layer: layers.20.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 191, layer: layers.20.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 192, layer: layers.21.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 193, layer: layers.21.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 194, layer: layers.21.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 195, layer: layers.21.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 196, layer: layers.21.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 197, layer: layers.21.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 198, layer: layers.21.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 199, layer: layers.21.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 200, layer: layers.21.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 201, layer: layers.22.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 202, layer: layers.22.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 203, layer: layers.22.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 204, layer: layers.22.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 205, layer: layers.22.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 206, layer: layers.22.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 207, layer: layers.22.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 208, layer: layers.22.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 209, layer: layers.22.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 210, layer: layers.23.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 211, layer: layers.23.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 212, layer: layers.23.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 213, layer: layers.23.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 214, layer: layers.23.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 215, layer: layers.23.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 216, layer: layers.23.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 217, layer: layers.23.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 218, layer: layers.23.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 219, layer: layers.24.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 220, layer: layers.24.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 221, layer: layers.24.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 222, layer: layers.24.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 223, layer: layers.24.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 224, layer: layers.24.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 225, layer: layers.24.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 226, layer: layers.24.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 227, layer: layers.24.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 228, layer: layers.25.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 229, layer: layers.25.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 230, layer: layers.25.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 231, layer: layers.25.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 232, layer: layers.25.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 233, layer: layers.25.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 234, layer: layers.25.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 235, layer: layers.25.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 236, layer: layers.25.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 237, layer: layers.26.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 238, layer: layers.26.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 239, layer: layers.26.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 240, layer: layers.26.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 241, layer: layers.26.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 242, layer: layers.26.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 243, layer: layers.26.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 244, layer: layers.26.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 245, layer: layers.26.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 246, layer: layers.27.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 247, layer: layers.27.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 248, layer: layers.27.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 249, layer: layers.27.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 250, layer: layers.27.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 251, layer: layers.27.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 252, layer: layers.27.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 253, layer: layers.27.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 254, layer: layers.27.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 255, layer: layers.28.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 256, layer: layers.28.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 257, layer: layers.28.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 258, layer: layers.28.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 259, layer: layers.28.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 260, layer: layers.28.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 261, layer: layers.28.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 262, layer: layers.28.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 263, layer: layers.28.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 264, layer: layers.29.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 265, layer: layers.29.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 266, layer: layers.29.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 267, layer: layers.29.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 268, layer: layers.29.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 269, layer: layers.29.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 270, layer: layers.29.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 271, layer: layers.29.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 272, layer: layers.29.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 273, layer: layers.30.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 274, layer: layers.30.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 275, layer: layers.30.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 276, layer: layers.30.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 277, layer: layers.30.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 278, layer: layers.30.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 279, layer: layers.30.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 280, layer: layers.30.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 281, layer: layers.30.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 282, layer: layers.31.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 283, layer: layers.31.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 284, layer: layers.31.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 285, layer: layers.31.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 286, layer: layers.31.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 287, layer: layers.31.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 288, layer: layers.31.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 289, layer: layers.31.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 290, layer: layers.31.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 291, layer: layers.32.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 292, layer: layers.32.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 293, layer: layers.32.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 294, layer: layers.32.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 295, layer: layers.32.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 296, layer: layers.32.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 297, layer: layers.32.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 298, layer: layers.32.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 299, layer: layers.32.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 300, layer: layers.33.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 301, layer: layers.33.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 302, layer: layers.33.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 303, layer: layers.33.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 304, layer: layers.33.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 305, layer: layers.33.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 306, layer: layers.33.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 307, layer: layers.33.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 308, layer: layers.33.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 309, layer: layers.34.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 310, layer: layers.34.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 311, layer: layers.34.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 312, layer: layers.34.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 313, layer: layers.34.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 314, layer: layers.34.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 315, layer: layers.34.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 316, layer: layers.34.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 317, layer: layers.34.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 318, layer: layers.35.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 319, layer: layers.35.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 320, layer: layers.35.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 321, layer: layers.35.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 322, layer: layers.35.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 323, layer: layers.35.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 324, layer: layers.35.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 325, layer: layers.35.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 326, layer: layers.35.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 327, layer: layers.36.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 328, layer: layers.36.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 329, layer: layers.36.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 330, layer: layers.36.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 331, layer: layers.36.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 332, layer: layers.36.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 333, layer: layers.36.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 334, layer: layers.36.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 335, layer: layers.36.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 336, layer: layers.37.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 337, layer: layers.37.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 338, layer: layers.37.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 339, layer: layers.37.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 340, layer: layers.37.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 341, layer: layers.37.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 342, layer: layers.37.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 343, layer: layers.37.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 344, layer: layers.37.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 345, layer: layers.38.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 346, layer: layers.38.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 347, layer: layers.38.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 348, layer: layers.38.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 349, layer: layers.38.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 350, layer: layers.38.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 351, layer: layers.38.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 352, layer: layers.38.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 353, layer: layers.38.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 354, layer: layers.39.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 355, layer: layers.39.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 356, layer: layers.39.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 357, layer: layers.39.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 358, layer: layers.39.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 359, layer: layers.39.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 360, layer: layers.39.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 361, layer: layers.39.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 362, layer: layers.39.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 363, layer: layers.40.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 364, layer: layers.40.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 365, layer: layers.40.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 366, layer: layers.40.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 367, layer: layers.40.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 368, layer: layers.40.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 369, layer: layers.40.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 370, layer: layers.40.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 371, layer: layers.40.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 372, layer: layers.41.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 373, layer: layers.41.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 374, layer: layers.41.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 375, layer: layers.41.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 376, layer: layers.41.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 377, layer: layers.41.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 378, layer: layers.41.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 379, layer: layers.41.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 380, layer: layers.41.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 381, layer: layers.42.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 382, layer: layers.42.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 383, layer: layers.42.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 384, layer: layers.42.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 385, layer: layers.42.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 386, layer: layers.42.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 387, layer: layers.42.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 388, layer: layers.42.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 389, layer: layers.42.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 390, layer: layers.43.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 391, layer: layers.43.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 392, layer: layers.43.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 393, layer: layers.43.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 394, layer: layers.43.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 395, layer: layers.43.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 396, layer: layers.43.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 397, layer: layers.43.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 398, layer: layers.43.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 399, layer: layers.44.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 400, layer: layers.44.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 401, layer: layers.44.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 402, layer: layers.44.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 403, layer: layers.44.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 404, layer: layers.44.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 405, layer: layers.44.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 406, layer: layers.44.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 407, layer: layers.44.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 408, layer: layers.45.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 409, layer: layers.45.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 410, layer: layers.45.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 411, layer: layers.45.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 412, layer: layers.45.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 413, layer: layers.45.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 414, layer: layers.45.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 415, layer: layers.45.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 416, layer: layers.45.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 417, layer: layers.46.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 418, layer: layers.46.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 419, layer: layers.46.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 420, layer: layers.46.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 421, layer: layers.46.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 422, layer: layers.46.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 423, layer: layers.46.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 424, layer: layers.46.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 425, layer: layers.46.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 426, layer: layers.47.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 427, layer: layers.47.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 428, layer: layers.47.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 429, layer: layers.47.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 430, layer: layers.47.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 431, layer: layers.47.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 432, layer: layers.47.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 433, layer: layers.47.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 434, layer: layers.47.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 435, layer: layers.48.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 436, layer: layers.48.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 437, layer: layers.48.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 438, layer: layers.48.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 439, layer: layers.48.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 440, layer: layers.48.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 441, layer: layers.48.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 442, layer: layers.48.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 443, layer: layers.48.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 444, layer: layers.49.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 445, layer: layers.49.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 446, layer: layers.49.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 447, layer: layers.49.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 448, layer: layers.49.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 449, layer: layers.49.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 450, layer: layers.49.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 451, layer: layers.49.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 452, layer: layers.49.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 453, layer: layers.50.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 454, layer: layers.50.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 455, layer: layers.50.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 456, layer: layers.50.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 457, layer: layers.50.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 458, layer: layers.50.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 459, layer: layers.50.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 460, layer: layers.50.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 461, layer: layers.50.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 462, layer: layers.51.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 463, layer: layers.51.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 464, layer: layers.51.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 465, layer: layers.51.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 466, layer: layers.51.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 467, layer: layers.51.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 468, layer: layers.51.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 469, layer: layers.51.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 470, layer: layers.51.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 471, layer: layers.52.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 472, layer: layers.52.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 473, layer: layers.52.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 474, layer: layers.52.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 475, layer: layers.52.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 476, layer: layers.52.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 477, layer: layers.52.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 478, layer: layers.52.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 479, layer: layers.52.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 480, layer: layers.53.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 481, layer: layers.53.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 482, layer: layers.53.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 483, layer: layers.53.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 484, layer: layers.53.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 485, layer: layers.53.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 486, layer: layers.53.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 487, layer: layers.53.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 488, layer: layers.53.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 489, layer: layers.54.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 490, layer: layers.54.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 491, layer: layers.54.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 492, layer: layers.54.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 493, layer: layers.54.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 494, layer: layers.54.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 495, layer: layers.54.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 496, layer: layers.54.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 497, layer: layers.54.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 498, layer: layers.55.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 499, layer: layers.55.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 500, layer: layers.55.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 501, layer: layers.55.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 502, layer: layers.55.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 503, layer: layers.55.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 504, layer: layers.55.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 505, layer: layers.55.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 506, layer: layers.55.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 507, layer: layers.56.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 508, layer: layers.56.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 509, layer: layers.56.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 510, layer: layers.56.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 511, layer: layers.56.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 512, layer: layers.56.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 513, layer: layers.56.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 514, layer: layers.56.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 515, layer: layers.56.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 516, layer: layers.57.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 517, layer: layers.57.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 518, layer: layers.57.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 519, layer: layers.57.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 520, layer: layers.57.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 521, layer: layers.57.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 522, layer: layers.57.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 523, layer: layers.57.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 524, layer: layers.57.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 525, layer: layers.58.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 526, layer: layers.58.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 527, layer: layers.58.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 528, layer: layers.58.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 529, layer: layers.58.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 530, layer: layers.58.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 531, layer: layers.58.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 532, layer: layers.58.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 533, layer: layers.58.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 534, layer: layers.59.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 535, layer: layers.59.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 536, layer: layers.59.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 537, layer: layers.59.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 538, layer: layers.59.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 539, layer: layers.59.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 540, layer: layers.59.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 541, layer: layers.59.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 542, layer: layers.59.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 543, layer: layers.60.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 544, layer: layers.60.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 545, layer: layers.60.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 546, layer: layers.60.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 547, layer: layers.60.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 548, layer: layers.60.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 549, layer: layers.60.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 550, layer: layers.60.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 551, layer: layers.60.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 552, layer: layers.61.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 553, layer: layers.61.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 554, layer: layers.61.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 555, layer: layers.61.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 556, layer: layers.61.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 557, layer: layers.61.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 558, layer: layers.61.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 559, layer: layers.61.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 560, layer: layers.61.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 561, layer: layers.62.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 562, layer: layers.62.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 563, layer: layers.62.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 564, layer: layers.62.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 565, layer: layers.62.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 566, layer: layers.62.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 567, layer: layers.62.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 568, layer: layers.62.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 569, layer: layers.62.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 570, layer: layers.63.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 571, layer: layers.63.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 572, layer: layers.63.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 573, layer: layers.63.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 574, layer: layers.63.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 575, layer: layers.63.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 576, layer: layers.63.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 577, layer: layers.63.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 578, layer: layers.63.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 579, layer: layers.64.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 580, layer: layers.64.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 581, layer: layers.64.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 582, layer: layers.64.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 583, layer: layers.64.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 584, layer: layers.64.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 585, layer: layers.64.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 586, layer: layers.64.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 587, layer: layers.64.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 588, layer: layers.65.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 589, layer: layers.65.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 590, layer: layers.65.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 591, layer: layers.65.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 592, layer: layers.65.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 593, layer: layers.65.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 594, layer: layers.65.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 595, layer: layers.65.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 596, layer: layers.65.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 597, layer: layers.66.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 598, layer: layers.66.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 599, layer: layers.66.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 600, layer: layers.66.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 601, layer: layers.66.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 602, layer: layers.66.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 603, layer: layers.66.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 604, layer: layers.66.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 605, layer: layers.66.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 606, layer: layers.67.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 607, layer: layers.67.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 608, layer: layers.67.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 609, layer: layers.67.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 610, layer: layers.67.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 611, layer: layers.67.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 612, layer: layers.67.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 613, layer: layers.67.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 614, layer: layers.67.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 615, layer: layers.68.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 616, layer: layers.68.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 617, layer: layers.68.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 618, layer: layers.68.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 619, layer: layers.68.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 620, layer: layers.68.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 621, layer: layers.68.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 622, layer: layers.68.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 623, layer: layers.68.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 624, layer: layers.69.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 625, layer: layers.69.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 626, layer: layers.69.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 627, layer: layers.69.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 628, layer: layers.69.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 629, layer: layers.69.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 630, layer: layers.69.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 631, layer: layers.69.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 632, layer: layers.69.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 633, layer: layers.70.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 634, layer: layers.70.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 635, layer: layers.70.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 636, layer: layers.70.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 637, layer: layers.70.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 638, layer: layers.70.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 639, layer: layers.70.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 640, layer: layers.70.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 641, layer: layers.70.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 642, layer: layers.71.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 643, layer: layers.71.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 644, layer: layers.71.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 645, layer: layers.71.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 646, layer: layers.71.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 647, layer: layers.71.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 648, layer: layers.71.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 649, layer: layers.71.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 650, layer: layers.71.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 651, layer: layers.72.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 652, layer: layers.72.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 653, layer: layers.72.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 654, layer: layers.72.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 655, layer: layers.72.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 656, layer: layers.72.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 657, layer: layers.72.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 658, layer: layers.72.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 659, layer: layers.72.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 660, layer: layers.73.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 661, layer: layers.73.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 662, layer: layers.73.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 663, layer: layers.73.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 664, layer: layers.73.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 665, layer: layers.73.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 666, layer: layers.73.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 667, layer: layers.73.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 668, layer: layers.73.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 669, layer: layers.74.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 670, layer: layers.74.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 671, layer: layers.74.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 672, layer: layers.74.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 673, layer: layers.74.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 674, layer: layers.74.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 675, layer: layers.74.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 676, layer: layers.74.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 677, layer: layers.74.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 678, layer: layers.75.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 679, layer: layers.75.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 680, layer: layers.75.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 681, layer: layers.75.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 682, layer: layers.75.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 683, layer: layers.75.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 684, layer: layers.75.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 685, layer: layers.75.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 686, layer: layers.75.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 687, layer: layers.76.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 688, layer: layers.76.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 689, layer: layers.76.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 690, layer: layers.76.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 691, layer: layers.76.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 692, layer: layers.76.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 693, layer: layers.76.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 694, layer: layers.76.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 695, layer: layers.76.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 696, layer: layers.77.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 697, layer: layers.77.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 698, layer: layers.77.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 699, layer: layers.77.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 700, layer: layers.77.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 701, layer: layers.77.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 702, layer: layers.77.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 703, layer: layers.77.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 704, layer: layers.77.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 705, layer: layers.78.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 706, layer: layers.78.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 707, layer: layers.78.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 708, layer: layers.78.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 709, layer: layers.78.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 710, layer: layers.78.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 711, layer: layers.78.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 712, layer: layers.78.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 713, layer: layers.78.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 714, layer: layers.79.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n",
      "Index: 715, layer: layers.79.attention.wk.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 716, layer: layers.79.attention.wv.weight, Layer size: torch.Size([128, 8192])\n",
      "Index: 717, layer: layers.79.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n",
      "Index: 718, layer: layers.79.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 719, layer: layers.79.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n",
      "Index: 720, layer: layers.79.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n",
      "Index: 721, layer: layers.79.attention_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 722, layer: layers.79.ffn_norm.weight, Layer size: torch.Size([8192])\n",
      "Index: 723, layer: rope.freqs, Layer size: torch.Size([64])\n"
     ]
    }
   ],
   "source": [
    "# check layers and dimensions (optional)\n",
    "state_dict = torch.load(f\"{load_path}/consolidated.0{1}.pth\")\n",
    "for index, (key, value) in enumerate(state_dict.items()):\n",
    "    print(f\"Index: {index}, layer: {key}, Layer size: {value.size()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5a65e1f",
   "metadata": {},
   "source": [
    "## Step 2: Train a tokenizer from scratch using domain-specific data to get a Domain Specific Tokenizer."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa45b02d-2e2f-4b81-9cba-ed07cbda5b9d",
   "metadata": {},
   "source": [
    "First, we train a tokenizer from scratch using domain-specific data.\n",
    "\n",
    "The tokenizer that we use is the facebook/opt-350m model tokenizer available here on <a href=https://huggingface.co/facebook/opt-350m>hugging face</a>. Similar to the llama-2 tokenizer, opt-350m tokenizer is also a Byte Pair Encoding (BPE) model and since we are training from scratch we could use any of them. Infact, we can use any model's tokenizer that is implemented based on BPE since the training algorithm inside the tokenizer is what matters. However, we chose opt-350m since it has a more general purpose design and can be used flexibly across different tasks/domains and with various models beyond the OPT series. On the other hand llama-2 tokenizer is designed specifically for llama-2 architecture, optimizing performance for tasks that llama-2 model is intended to handle. \n",
    "\n",
    "The two hyperparameters that need to be set here are ```batch_size``` and ```vocab_size```. <br>\n",
    "\n",
    "```vocab_size``` : is the target vocab size in finetuning the tokenizer. This depends on the original tokenizer and should be slightly higher than half of the original vocab size. Note that this doesn't have to equal the number of new tokens that will be added. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "518ca72b-4ab8-4538-8eb4-42d648346347",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is a directory: True\n"
     ]
    }
   ],
   "source": [
    "data_root = \"./data/all_jsonl_data_sample/\"       # path where the domain specific data is stored\n",
    "save_root = \"./models/tokenizer/llama2/\"    # path to save the finetuned opt tokenizer\n",
    "batch_size = 1000    # batch size used in the tokenization process\n",
    "vocab_size = 20000   # target vocab size for training opt tokenizer\n",
    "\n",
    "# ensure that the directory exists before changing permissions\n",
    "directory = \"../code/\"\n",
    "is_directory = os.path.isdir(directory)\n",
    "print(f\"Is a directory: {is_directory}\")\n",
    "\n",
    "# change permissions to ensure we have read, write and execute permissions\n",
    "! chmod ugo+rwx ../code/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "519e6c13-45c3-4f52-9ed0-770d4ec62766",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Before Training: \n",
      "total token cnt 66025\n",
      "\n",
      "\n",
      "\n",
      "After Training: \n",
      "total token cnt 47712\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('./models/tokenizer/llama2/custom_tokenizer_init_20000_json/tokenizer_config.json',\n",
       " './models/tokenizer/llama2/custom_tokenizer_init_20000_json/special_tokens_map.json',\n",
       " './models/tokenizer/llama2/custom_tokenizer_init_20000_json/vocab.json',\n",
       " './models/tokenizer/llama2/custom_tokenizer_init_20000_json/merges.txt',\n",
       " './models/tokenizer/llama2/custom_tokenizer_init_20000_json/added_tokens.json',\n",
       " './models/tokenizer/llama2/custom_tokenizer_init_20000_json/tokenizer.json')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train a tokenizer from scratch and save output files\n",
    "keys = [\"text\"] # keys to extract from json files\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-350m\") # load pre-trained tokenizer (https://huggingface.co/facebook/opt-350m)\n",
    "# Train the tokenizer from scratch on a new corpus with the same defaults (in terms of special tokens or tokenization pipeline) as the current one.\n",
    "tokenizer = train_tokenizer(data_root, batch_size, vocab_size, tokenizer, keys)\n",
    "\n",
    "#Save and print paths\n",
    "tokenizer.save_pretrained(save_root + \"custom_tokenizer_init_\" + str(vocab_size) + \"_json\")"
   ]
  },
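  {
   "cell_type": "markdown",
   "id": "c3f1a7d2-5b8e-4f60-9a1c-2d7e4b6f8a01",
   "metadata": {},
   "source": [
    "The `train_tokenizer` helper called above is not shown in this notebook. As a rough illustration only, the sketch below shows one way such a helper could be written with the Hugging Face `train_new_from_iterator` method, which retrains a fast tokenizer on a new corpus while keeping its special tokens and tokenization pipeline. The names `iter_text_batches` and `train_tokenizer_sketch`, and the assumption that the data is stored as `.jsonl` files directly under `data_root`, are hypothetical and may differ from the actual helper.\n",
    "\n",
    "```python\n",
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def iter_text_batches(data_root, keys, batch_size):\n",
    "    # Yield lists of strings read from every .jsonl file under data_root (assumed layout).\n",
    "    batch = []\n",
    "    for path in sorted(Path(data_root).glob(\"*.jsonl\")):\n",
    "        with open(path, encoding=\"utf-8\") as f:\n",
    "            for line in f:\n",
    "                if not line.strip():\n",
    "                    continue  # skip blank lines\n",
    "                record = json.loads(line)\n",
    "                for key in keys:\n",
    "                    if record.get(key):\n",
    "                        batch.append(record[key])\n",
    "                if len(batch) >= batch_size:\n",
    "                    yield batch\n",
    "                    batch = []\n",
    "    if batch:\n",
    "        yield batch  # final partial batch\n",
    "\n",
    "\n",
    "def train_tokenizer_sketch(data_root, batch_size, vocab_size, tokenizer, keys):\n",
    "    # Learn a new vocabulary and merges from the corpus; `tokenizer` must be a fast tokenizer.\n",
    "    batches = iter_text_batches(data_root, keys, batch_size)\n",
    "    return tokenizer.train_new_from_iterator(batches, vocab_size=vocab_size)\n",
    "```\n",
    "\n",
    "Streaming the corpus in batches keeps memory use bounded, since the text files never need to be loaded in full.\n"
   ]
  },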
  {
   "cell_type": "markdown",
   "id": "127e0591-fbaa-41bc-87a5-f594587ea12d",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Step 3: From the vocabulary of the newly trained tokenizer, identify tokens that are absent in the general-purpose tokenizer and are rarely found in general-purpose datasets. Next, expand the general-purpose tokenizer with the newly identified tokens to get an extended Tokenizer.\n",
    "\n",
    "Here we expand/resize the model embeddings of the original general-purpose tokenizer with the newly identified tokens in Step 3 to get an extended tokenizer.\n",
    "\n",
    "The two hyperparemeters that need to be set here are ```split``` and ```model_type```. \n",
    "\n",
    "```split```: is the number of partitions to split the embeddings in (.pt files) for the purpose of model parallelism.\n",
    "\n",
    "```model_type``` : this is the original tokenizer model (llama2 in our case)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ec202bf4-3508-4a64-90c6-debcf116e81e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Domain vocab size: 5965\n",
      "token pattern:  [a-zA-Z]\n",
      "Num of added tokens and dropped tokens 4931 1034\n",
      "Original model pieces: 32000\n",
      "input: \"/large_experiments/theorem/datasets/MERGED/all.test1.merged\"\n",
      "model_prefix: \"spm_model_32k_200M_charcov099995_allowWSO__v2\"\n",
      "model_type: BPE\n",
      "vocab_size: 32000\n",
      "self_test_sample_size: 0\n",
      "input_format: \"text\"\n",
      "character_coverage: 0.99995\n",
      "input_sentence_size: 200000000\n",
      "seed_sentencepiece_size: 1000000\n",
      "shrinking_factor: 0.75\n",
      "num_threads: 80\n",
      "num_sub_iterations: 2\n",
      "max_sentence_length: 4192\n",
      "shuffle_input_sentence: true\n",
      "max_sentencepiece_length: 16\n",
      "split_by_unicode_script: true\n",
      "split_by_whitespace: true\n",
      "split_by_number: true\n",
      "treat_whitespace_as_suffix: false\n",
      "split_digits: true\n",
      "allow_whitespace_only_pieces: true\n",
      "vocabulary_output_piece_score: true\n",
      "hard_vocab_limit: true\n",
      "use_all_vocab: false\n",
      "byte_fallback: true\n",
      "required_chars: \"\"\n",
      "unk_id: 0\n",
      "bos_id: 1\n",
      "eos_id: 2\n",
      "pad_id: -1\n",
      "unk_surface: \" \\342\\201\\207 \"\n",
      "unk_piece: \"<unk>\"\n",
      "bos_piece: \"<s>\"\n",
      "eos_piece: \"</s>\"\n",
      "pad_piece: \"<pad>\"\n",
      "train_extremely_large_corpus: false\n",
      "enable_differential_privacy: false\n",
      "differential_privacy_noise_level: 0.0\n",
      "differential_privacy_clipping_threshold: 0\n",
      "\n",
      "original vocab size:  32000\n",
      "new token cnt:  1400\n",
      "add token cnt:  2048\n",
      "add normal token cnt:  1400\n",
      "add dummy token cnt:  648\n",
      "New model pieces: 34048\n",
      "input: \"/large_experiments/theorem/datasets/MERGED/all.test1.merged\"\n",
      "model_prefix: \"spm_model_32k_200M_charcov099995_allowWSO__v2\"\n",
      "model_type: BPE\n",
      "vocab_size: 32000\n",
      "self_test_sample_size: 0\n",
      "input_format: \"text\"\n",
      "character_coverage: 0.99995\n",
      "input_sentence_size: 200000000\n",
      "seed_sentencepiece_size: 1000000\n",
      "shrinking_factor: 0.75\n",
      "num_threads: 80\n",
      "num_sub_iterations: 2\n",
      "max_sentence_length: 4192\n",
      "shuffle_input_sentence: true\n",
      "max_sentencepiece_length: 16\n",
      "split_by_unicode_script: true\n",
      "split_by_whitespace: true\n",
      "split_by_number: true\n",
      "treat_whitespace_as_suffix: false\n",
      "split_digits: true\n",
      "allow_whitespace_only_pieces: true\n",
      "vocabulary_output_piece_score: true\n",
      "hard_vocab_limit: true\n",
      "use_all_vocab: false\n",
      "byte_fallback: true\n",
      "required_chars: \"\"\n",
      "unk_id: 0\n",
      "bos_id: 1\n",
      "eos_id: 2\n",
      "pad_id: -1\n",
      "unk_surface: \" \\342\\201\\207 \"\n",
      "unk_piece: \"<unk>\"\n",
      "bos_piece: \"<s>\"\n",
      "eos_piece: \"</s>\"\n",
      "pad_piece: \"<pad>\"\n",
      "train_extremely_large_corpus: false\n",
      "enable_differential_privacy: false\n",
      "differential_privacy_noise_level: 0.0\n",
      "differential_privacy_clipping_threshold: 0\n",
      "\n",
      "Parent directory './models/tokenizer/llama2/new_tokenizer' exists.\n",
      "Parent directory './models/tokenizer/llama2/new_tokenizer' exists.\n",
      "word_embedding shape:  torch.Size([32000, 8192])\n",
      "output_layer shape:  torch.Size([32000, 8192])\n",
      "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n",
      "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n",
      "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n",
      "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n",
      "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n",
      "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n",
      "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n",
      "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n",
      "Completed saving new embeddings\n",
      "Vocabulary path for extended tokenizer:  ./models/tokenizer/llama2/new_tokenizer/code_gen_vocab.json\n",
      "Tokenizer model path for extended tokenizer:  ./models/tokenizer/llama2/new_tokenizer/tokenizer_code_gen.model\n",
      "Modified embedding weights path for extended tokenizer:  ./models/weight/llama2/new_llama2-hf_weight/\n"
     ]
    }
   ],
   "source": [
    "split = 8      # number of partitions to split the embeddings of domain-adapted tokenizer\n",
    "model_type = \"llama2\" # Add more model_types if you want the codebase to support alternate ones\n",
    "extend_tokenizer(vocab_size, split, model_type)"
   ]
  },
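  {
   "cell_type": "markdown",
   "id": "d4a2b8e3-6c9f-4a71-8b2d-3e8f5c7a9b02",
   "metadata": {},
   "source": [
    "The internals of `extend_tokenizer` are not shown here. The sketch below illustrates only two ideas from this step, under stated assumptions: finding domain tokens that are absent from the original vocabulary, and padding the number of added tokens with dummy entries. The function names are hypothetical, and the padding multiple of 1024 is an assumption chosen only because it reproduces the counts printed above (1400 new tokens, 648 dummy tokens, 2048 added in total). The log also shows a token pattern filter and dropped tokens, which this sketch omits.\n",
    "\n",
    "```python\n",
    "def missing_tokens(domain_vocab, original_vocab):\n",
    "    # Tokens learned on domain data that the original tokenizer does not contain.\n",
    "    original = set(original_vocab)\n",
    "    return [tok for tok in domain_vocab if tok not in original]\n",
    "\n",
    "\n",
    "def pad_count(n_new, multiple):\n",
    "    # Round n_new up to the next multiple; the gap is filled with dummy tokens.\n",
    "    padded = -(-n_new // multiple) * multiple\n",
    "    return padded, padded - n_new\n",
    "\n",
    "\n",
    "print(missing_tokens([\"module\", \"wire\", \"the\"], [\"the\", \"of\"]))  # ['module', 'wire']\n",
    "\n",
    "padded, dummy = pad_count(1400, 1024)  # 1024 is an assumed multiple\n",
    "print(padded, dummy, 32000 + padded)   # 2048 648 34048\n",
    "```\n",
    "\n",
    "Padding the vocabulary to a round size is a common way to keep the embedding matrix evenly divisible across model-parallel partitions.\n"
   ]
  },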
  {
   "cell_type": "markdown",
   "id": "82ddf5a4-663f-40b0-9477-bd3d3e803c12",
   "metadata": {},
   "source": [
    "## Step 4: Use the extended Tokenizer to anylze the frequency of newly added tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c78a7b5-8122-4d95-afd1-997f047c37f1",
   "metadata": {},
   "source": [
    "Here we apply the extended tokenizer to the domain-specific dataset, analyzing the usage frequencies of the newly-added tokens, and selecting the top-K tokens in a way that their cumulative frequency accounts for approximately 98% (a hyper-parameter: ```freq_threshold```) of the total frequency of the new tokens.\n",
    "\n",
    "The idea is that only high-frequency tokens will be added to the vocabulary of the original tokenizer to get the final domain adapted tokenizer. \n",
    "\n",
    "The benefits of high-frequency token analysis have been explored in several studies: ([Liu, Mingjie, et al](https://research.nvidia.com/publication/2023-10_chipnemo-domain-adapted-llms-chip-design); [Lian, Haoran, et al](https://arxiv.org/abs/2404.17808)).This is because previous studies have shown that disparities in token frequencies can result in imbalanced learning difficulties across different tokens. For instance, low frequency tokens are harder to learn for models ([Su, Zhenpeng, et al](https://arxiv.org/abs/2310.19531); [Lin, Tsung-Yi, et al](https://openaccess.thecvf.com/content_iccv_2017/html/Lin_Focal_Loss_for_ICCV_2017_paper.html)).\n",
    "\n",
    "We use two functions for frequency analysis. Helper function `analyze_token_usage` applies the extended tokenizer to domain specific data, and stores the usage/occurence frequencies of the newly-added tokens at `token_usage_path`. <br>\n",
    "\n",
    "Helper function `get_high_freq_tokens` looks at the token usage frequencies from above and performs a binary search to search for domain specific tokens with usage frequency above the specified threshold (`freq_threshold` parameter). It stores the tokens it finds at `high_freq_tokens_path`."
   ]
  },
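  {
   "cell_type": "markdown",
   "id": "e5b3c9f4-7d0a-4b82-9c3e-4f9a6d8b0c03",
   "metadata": {},
   "source": [
    "As a rough sketch of the selection rule described above, the snippet below keeps the most frequently used new tokens until their cumulative share of total usage reaches `freq_threshold`. It uses a simple sorted cumulative sum rather than the binary search that `get_high_freq_tokens` performs, and the function name `select_high_freq_tokens` and the example counts are hypothetical.\n",
    "\n",
    "```python\n",
    "def select_high_freq_tokens(token_counts, freq_threshold=0.98):\n",
    "    # Sort by descending usage count, then keep tokens until the running\n",
    "    # share of total usage reaches freq_threshold.\n",
    "    total = sum(token_counts.values())\n",
    "    selected, running = [], 0\n",
    "    for token, count in sorted(token_counts.items(), key=lambda kv: kv[1], reverse=True):\n",
    "        if total == 0 or running / total >= freq_threshold:\n",
    "            break\n",
    "        selected.append(token)\n",
    "        running += count\n",
    "    return selected\n",
    "\n",
    "\n",
    "counts = {\"clk\": 900, \"netlist\": 85, \"rst_n\": 10, \"xyzzy\": 5}\n",
    "print(select_high_freq_tokens(counts))  # ['clk', 'netlist']\n",
    "```\n",
    "\n",
    "In this toy example the two most frequent tokens already cover 98.5% of usage, so the two rare tokens are left out of the final vocabulary.\n"
   ]
  },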
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b36a5465-b409-4d32-9a2f-1dd7c91ef917",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "split = 8      # number of partitions to split the embeddings of domain-adapted tokenizer\n",
    "model_type = \"llama2\"\n",
    "tag = \"code_gen\"\n",
    "keys = [\"text\"]\n",
    "# path to the saved extended tokenizer (from previous tep)\n",
    "extended_tokenizer_path = f\"./models/tokenizer/{model_type}/new_tokenizer/tokenizer_{tag}.model\"\n",
    "# path to save token usage frequency analysis results\n",
    "token_usage_path = f\"./models/tokenizer/{model_type}/new_tokenizer/{model_type}_token_usage.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7e2b40db-82d6-48ca-a900-145647b4dff1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab_size:  34048\n",
      "ori cnt and new cnt:  2209.0 22.0\n",
      "ori cnt and new cnt:  1764.0 20.0\n",
      "ori cnt and new cnt:  4062.0 259.0\n",
      "ori cnt and new cnt:  406.0 7.0\n",
      "ori cnt and new cnt:  1872.0 39.0\n",
      "ori cnt and new cnt:  645.0 32.0\n",
      "ori cnt and new cnt:  2655.0 20.0\n",
      "ori cnt and new cnt:  154.0 6.0\n",
      "ori cnt and new cnt:  997.0 30.0\n",
      "ori cnt and new cnt:  523.0 29.0\n",
      "ori cnt and new cnt:  523.0 29.0\n",
      "ori cnt and new cnt:  2317.0 95.0\n",
      "ori cnt and new cnt:  419.0 10.0\n",
      "ori cnt and new cnt:  813.0 13.0\n",
      "ori cnt and new cnt:  18796.0 1238.0\n",
      "ori cnt and new cnt:  3327.0 113.0\n",
      "ori cnt and new cnt:  963.0 29.0\n",
      "ori cnt and new cnt:  500.0 21.0\n",
      "ori cnt and new cnt:  610.0 22.0\n",
      "ori cnt and new cnt:  879.0 18.0\n",
      "ori cnt and new cnt:  1681.0 88.0\n",
      "ori cnt and new cnt:  654.0 16.0\n",
      "ori cnt and new cnt:  62.0 2.0\n",
      "ori cnt and new cnt:  1230.0 151.0\n",
      "ori cnt and new cnt:  786.0 40.0\n",
      "ori cnt and new cnt:  1454.0 22.0\n",
      "ori cnt and new cnt:  1237.0 29.0\n",
      "ori cnt and new cnt:  1610.0 60.0\n",
      "ori cnt and new cnt:  383.0 20.0\n",
      "ori cnt and new cnt:  766.0 22.0\n",
      "ori cnt and new cnt:  2361.0 20.0\n",
      "ori cnt and new cnt:  120.0 3.0\n",
      "ori cnt and new cnt:  714.0 31.0\n",
      "ori cnt and new cnt:  2185.0 137.0\n",
      "ori cnt and new cnt:  1270.0 75.0\n",
      "ori cnt and new cnt:  506.0 24.0\n"
     ]
    }
   ],
   "source": [
    "# analyze tokens using frequency analysis\n",
    "analyze_token_usage(data_root, extended_tokenizer_path, batch_size, keys, token_usage_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ea495ba4-8b65-4560-94a6-269e8af8a83a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# path to save selected high-frequency tokens (new tokens to be added)\n",
    "high_freq_tokens_path = f\"./models/tokenizer/{model_type}/new_tokenizer/{model_type}_freq_analy_new_token.json\"\n",
    "\n",
    "# hyperparameter \n",
    "freq_threshold = 0.98"
   ]
  },
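  {
   "cell_type": "markdown",
   "id": "8b41d7a9-top-k-selection-sketch",
   "metadata": {},
   "source": [
    "The selection criterion can be sketched as follows: rank the new tokens by usage count and keep the smallest prefix whose cumulative count reaches `freq_threshold` of the total. This is a simplified linear-scan illustration, not the binary search used inside `get_high_freq_tokens`; the function name `select_top_k_tokens` and the toy counts are hypothetical.\n",
    "\n",
    "```python\n",
    "def select_top_k_tokens(token_usage, freq_threshold=0.98):\n",
    "    # Rank tokens from most to least frequently used.\n",
    "    ranked = sorted(token_usage.items(), key=lambda kv: kv[1], reverse=True)\n",
    "    total = sum(count for _, count in ranked)\n",
    "    selected, running = [], 0\n",
    "    for token, count in ranked:\n",
    "        # Stop once the kept tokens already cover the requested share.\n",
    "        if total == 0 or running >= freq_threshold * total:\n",
    "            break\n",
    "        selected.append(token)\n",
    "        running += count\n",
    "    return selected\n",
    "\n",
    "\n",
    "# Toy usage counts for five hypothetical new tokens (total = 100).\n",
    "toy_usage = {'a': 50, 'b': 30, 'c': 15, 'd': 4, 'e': 1}\n",
    "top_tokens = select_top_k_tokens(toy_usage, freq_threshold=0.98)\n",
    "print(top_tokens)\n",
    "```\n",
    "\n",
    "With these toy counts, the first three tokens cover 95% of the usage, so a fourth is needed to reach 98%, and the rarest token is dropped."
   ]
  },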
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b816793c-a77c-4cfa-8317-f278cfdbe247",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./data/all_jsonl_data_sample/7db92aa7a05ae3eb86ec8bd0ab6e6768.lef.gz-0.jsonl\n",
      "[4 4 2 2 2 1 1 1 1 1 1 1 1] 21.56\n",
      "[4 4 2 2 2 1 1 1 1 1 1 1 1] 21.56\n",
      "3\n",
      "./data/all_jsonl_data_sample/7d3eb10b8384155f4f262b9d4a9d95b2.lef.gz-0.jsonl\n",
      "[4 2 2 2 2 2 1 1 1 1 1 1] 19.6\n",
      "[4 2 2 2 2 2 1 1 1 1 1 1] 19.6\n",
      "3\n",
      "./data/all_jsonl_data_sample/7d1e17d8e8367778544c7664a0dcca34.scala.gz-0.jsonl\n",
      "[31 30 30  7  4  3  3  3  3  3  2  2  2  2  2  2  2  2  2  2  2  2  2  2\n",
      "  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2\n",
      "  2  2  2  2  2  2  2  2  2  2  1  1  1  1  1  1  1  1  1  1  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1] 253.82\n",
      "[31 30 30  7  4  3  3  3  3  3  2  2  2  2  2  2  2  2  2  2  2  2  2  2\n",
      "  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2\n",
      "  2  2  2  2  2  2  2  2  2  2  1  1  1  1  1  1  1  1  1  1  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1] 253.82\n",
      "[31 30 30  7  4  3  3  3  3  3  2  2  2  2  2  2  2  2  2  2  2  2  2  2\n",
      "  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2\n",
      "  2  2  2  2  2  2  2  2  2  2  1  1  1  1  1  1  1  1  1  1  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1] 253.82\n",
      "[31 30 30  7  4  3  3  3  3  3  2  2  2  2  2  2  2  2  2  2  2  2  2  2\n",
      "  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2\n",
      "  2  2  2  2  2  2  2  2  2  2  1  1  1  1  1  1  1  1  1  1  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1] 253.82\n",
      "[31 30 30  7  4  3  3  3  3  3  2  2  2  2  2  2  2  2  2  2  2  2  2  2\n",
      "  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2\n",
      "  2  2  2  2  2  2  2  2  2  2  1  1  1  1  1  1  1  1  1  1  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1] 253.82\n",
      "31\n",
      "./data/all_jsonl_data_sample/7d9023bf5e97a4417b3e3d15bd0155e5.v.gz-0.jsonl\n",
      "[2 1 1 1 1 1] 6.859999999999999\n",
      "1\n",
      "./data/all_jsonl_data_sample/7d4746f7028947dbbf6f6be1e705a343.h.gz-0.jsonl\n",
      "[12  6  5  4  3  2  1  1  1  1  1  1  1] 38.22\n",
      "[12  6  5  4  3  2  1  1  1  1  1  1  1] 38.22\n",
      "[12  6  5  4  3  2  1  1  1  1  1  1  1] 38.22\n",
      "[12  6  5  4  3  2  1  1  1  1  1  1  1] 38.22\n",
      "12\n",
      "./data/all_jsonl_data_sample/7de4bd2089ed29650c6813a692c0b7fd.cdl.gz-0.jsonl\n",
      "[5 5 3 3 3 3 3 2 2 1 1 1] 31.36\n",
      "[5 5 3 3 3 3 3 2 2 1 1 1] 31.36\n",
      "4\n",
      "./data/all_jsonl_data_sample/7da75b519311e22a70fe54061b51b67c.sv.gz-0.jsonl\n",
      "[2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1] 19.6\n",
      "1\n",
      "./data/all_jsonl_data_sample/7d2caac63ccb0ee43f143dfad745a878.v.gz-0.jsonl\n",
      "[2 1 1 1 1] 5.88\n",
      "1\n",
      "./data/all_jsonl_data_sample/7d3ac231744dee023fffc01079a56367.v.gz-0.jsonl\n",
      "[4 3 3 3 3 3 2 2 1 1 1 1 1 1 1] 29.4\n",
      "[4 3 3 3 3 3 2 2 1 1 1 1 1 1 1] 29.4\n",
      "3\n",
      "./data/all_jsonl_data_sample/7d223aaa5ad782ba0e026a4fcd6a5e0d.v.gz-0.jsonl\n",
      "[4 3 3 3 3 3 2 2 1 1 1 1 1 1] 28.419999999999998\n",
      "[4 3 3 3 3 3 2 2 1 1 1 1 1 1] 28.419999999999998\n",
      "3\n",
      "./data/all_jsonl_data_sample/7d233e7cb17ecddca9baf0704309e739.v.gz-0.jsonl\n",
      "[4 3 3 3 3 3 2 2 1 1 1 1 1 1] 28.419999999999998\n",
      "[4 3 3 3 3 3 2 2 1 1 1 1 1 1] 28.419999999999998\n",
      "3\n",
      "./data/all_jsonl_data_sample/7de185d29809f5259616436204ae6c07.spice.gz-0.jsonl\n",
      "[21 20 20 16 16  1  1] 93.1\n",
      "[21 20 20 16 16  1  1] 93.1\n",
      "6\n",
      "./data/all_jsonl_data_sample/7d398c165432cac33c33442b2b2b9915.v.gz-0.jsonl\n",
      "[3 2 1 1 1 1 1] 9.8\n",
      "[3 2 1 1 1 1 1] 9.8\n",
      "3\n",
      "./data/all_jsonl_data_sample/7db39e3425d097664d5b3aa4800501ad.h.gz-0.jsonl\n",
      "[8 1 1 1 1 1] 12.74\n",
      "[8 1 1 1 1 1] 12.74\n",
      "6\n",
      "./data/all_jsonl_data_sample/7dbb099365a3ef31bbc60c3fc37be762.qip.gz-0.jsonl\n",
      "[143  27  26  25  20  19  19  17  17  17  17  17  17  17  17  17  16  16\n",
      "  13  12  11  11  11  10  10  10  10   9   8   8   8   7   7   7   7   7\n",
      "   6   6   6   6   6   6   6   6   6   6   6   5   5   5   5   5   5   5\n",
      "   5   5   5   5   5   4   4   4   4   4   4   4   4   4   4   4   4   4\n",
      "   4   4   4   4   4   4   4   4   4   4   4   3   3   3   3   3   3   3\n",
      "   3   3   3   3   3   3   3   3   3   3   3   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1] 1213.24\n",
      "[143  27  26  25  20  19  19  17  17  17  17  17  17  17  17  17  16  16\n",
      "  13  12  11  11  11  10  10  10  10   9   8   8   8   7   7   7   7   7\n",
      "   6   6   6   6   6   6   6   6   6   6   6   5   5   5   5   5   5   5\n",
      "   5   5   5   5   5   4   4   4   4   4   4   4   4   4   4   4   4   4\n",
      "   4   4   4   4   4   4   4   4   4   4   4   3   3   3   3   3   3   3\n",
      "   3   3   3   3   3   3   3   3   3   3   3   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1] 1213.24\n",
      "[143  27  26  25  20  19  19  17  17  17  17  17  17  17  17  17  16  16\n",
      "  13  12  11  11  11  10  10  10  10   9   8   8   8   7   7   7   7   7\n",
      "   6   6   6   6   6   6   6   6   6   6   6   5   5   5   5   5   5   5\n",
      "   5   5   5   5   5   4   4   4   4   4   4   4   4   4   4   4   4   4\n",
      "   4   4   4   4   4   4   4   4   4   4   4   3   3   3   3   3   3   3\n",
      "   3   3   3   3   3   3   3   3   3   3   3   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1] 1213.24\n",
      "[143  27  26  25  20  19  19  17  17  17  17  17  17  17  17  17  16  16\n",
      "  13  12  11  11  11  10  10  10  10   9   8   8   8   7   7   7   7   7\n",
      "   6   6   6   6   6   6   6   6   6   6   6   5   5   5   5   5   5   5\n",
      "   5   5   5   5   5   4   4   4   4   4   4   4   4   4   4   4   4   4\n",
      "   4   4   4   4   4   4   4   4   4   4   4   3   3   3   3   3   3   3\n",
      "   3   3   3   3   3   3   3   3   3   3   3   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1] 1213.24\n",
      "[143  27  26  25  20  19  19  17  17  17  17  17  17  17  17  17  16  16\n",
      "  13  12  11  11  11  10  10  10  10   9   8   8   8   7   7   7   7   7\n",
      "   6   6   6   6   6   6   6   6   6   6   6   5   5   5   5   5   5   5\n",
      "   5   5   5   5   5   4   4   4   4   4   4   4   4   4   4   4   4   4\n",
      "   4   4   4   4   4   4   4   4   4   4   4   3   3   3   3   3   3   3\n",
      "   3   3   3   3   3   3   3   3   3   3   3   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1] 1213.24\n",
      "[143  27  26  25  20  19  19  17  17  17  17  17  17  17  17  17  16  16\n",
      "  13  12  11  11  11  10  10  10  10   9   8   8   8   7   7   7   7   7\n",
      "   6   6   6   6   6   6   6   6   6   6   6   5   5   5   5   5   5   5\n",
      "   5   5   5   5   5   4   4   4   4   4   4   4   4   4   4   4   4   4\n",
      "   4   4   4   4   4   4   4   4   4   4   4   3   3   3   3   3   3   3\n",
      "   3   3   3   3   3   3   3   3   3   3   3   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1] 1213.24\n",
      "[143  27  26  25  20  19  19  17  17  17  17  17  17  17  17  17  16  16\n",
      "  13  12  11  11  11  10  10  10  10   9   8   8   8   7   7   7   7   7\n",
      "   6   6   6   6   6   6   6   6   6   6   6   5   5   5   5   5   5   5\n",
      "   5   5   5   5   5   4   4   4   4   4   4   4   4   4   4   4   4   4\n",
      "   4   4   4   4   4   4   4   4   4   4   4   3   3   3   3   3   3   3\n",
      "   3   3   3   3   3   3   3   3   3   3   3   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2\n",
      "   2   2   2   2   2   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1\n",
      "   1   1   1   1   1   1   1] 1213.24\n",
      "142\n",
      "./data/all_jsonl_data_sample/7dc1209d13f0e65aab95b30e28fdc7b0.spice.gz-0.jsonl\n",
      "[29 24 24 15 11  7  1  1  1] 110.74\n",
      "[29 24 24 15 11  7  1  1  1] 110.74\n",
      "7\n",
      "./data/all_jsonl_data_sample/7dddcc0a609031cc37396af611abd521.v.gz-0.jsonl\n",
      "[7 7 7 2 2 2 1 1] 28.419999999999998\n",
      "[7 7 7 2 2 2 1 1] 28.419999999999998\n",
      "[7 7 7 2 2 2 1 1] 28.419999999999998\n",
      "7\n",
      "./data/all_jsonl_data_sample/7dbf14d1da77bcf50408a3548fba5443.v.gz-0.jsonl\n",
      "[4 3 3 3 2 2 1 1 1 1] 20.58\n",
      "[4 3 3 3 2 2 1 1 1 1] 20.58\n",
      "3\n",
      "./data/all_jsonl_data_sample/7d1585b1aef10fb448a5b5b8fbd0b624.v.gz-0.jsonl\n",
      "[3 3 3 2 2 2 2 2 2 1] 21.56\n",
      "[3 3 3 2 2 2 2 2 2 1] 21.56\n",
      "3\n",
      "./data/all_jsonl_data_sample/7d0e56309d51283206ed83074b1ccf76.sv.gz-0.jsonl\n",
      "[4 4 1 1 1 1 1 1 1 1 1 1] 17.64\n",
      "[4 4 1 1 1 1 1 1 1 1 1 1] 17.64\n",
      "3\n",
      "./data/all_jsonl_data_sample/7dab8d5649d18c7a9c155b51392a6588.v.gz-0.jsonl\n",
      "[29 18  4  4  4  4  4  3  3  3  2  2  2  1  1  1  1  1  1] 86.24\n",
      "[29 18  4  4  4  4  4  3  3  3  2  2  2  1  1  1  1  1  1] 86.24\n",
      "[29 18  4  4  4  4  4  3  3  3  2  2  2  1  1  1  1  1  1] 86.24\n",
      "18\n",
      "./data/all_jsonl_data_sample/7dbb8e6199e137dcf3d7085bb0ff9975.v.gz-0.jsonl\n",
      "[11  3  1  1] 15.68\n",
      "[11  3  1  1] 15.68\n",
      "4\n",
      "./data/all_jsonl_data_sample/7d62d6bd44676d9f2ea5cdbbd594c4ef.c.gz-0.jsonl\n",
      "[2] 1.96\n",
      "2\n",
      "./data/all_jsonl_data_sample/7dbfcd8e236b3b802c78e9ab57b3a1d0.scala.gz-0.jsonl\n",
      "[63 11  7  7  5  4  4  3  3  3  3  3  2  2  2  2  2  2  2  2  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1] 147.98\n",
      "[63 11  7  7  5  4  4  3  3  3  3  3  2  2  2  2  2  2  2  2  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1] 147.98\n",
      "[63 11  7  7  5  4  4  3  3  3  3  3  2  2  2  2  2  2  2  2  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1] 147.98\n",
      "[63 11  7  7  5  4  4  3  3  3  3  3  2  2  2  2  2  2  2  2  1  1  1  1\n",
      "  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1] 147.98\n",
      "36\n",
      "./data/all_jsonl_data_sample/7d7b96f51a734da259a6a2ecf379cded.cdl.gz-0.jsonl\n",
      "[6 6 4 4 4 4 3 2 2 2 1 1 1] 39.2\n",
      "[6 6 4 4 4 4 3 2 2 2 1 1 1] 39.2\n",
      "[6 6 4 4 4 4 3 2 2 2 1 1 1] 39.2\n",
      "6\n",
      "./data/all_jsonl_data_sample/7d23810b472d58f3487c52b1f773189a.lef.gz-0.jsonl\n",
      "[4 3 2 2 2 1 1 1 1 1 1 1 1 1] 21.56\n",
      "[4 3 2 2 2 1 1 1 1 1 1 1 1 1] 21.56\n",
      "3\n",
      "./data/all_jsonl_data_sample/7dc70412013409c8b16c0c9e5f14fcfa.v.gz-0.jsonl\n",
      "[7 7 7 2 2 2 1 1] 28.419999999999998\n",
      "[7 7 7 2 2 2 1 1] 28.419999999999998\n",
      "[7 7 7 2 2 2 1 1] 28.419999999999998\n",
      "7\n",
      "./data/all_jsonl_data_sample/7d5dd4296bf9ada66d63916f69a46faf.emf.gz-0.jsonl\n",
      "[20 14  2  2  2  2  2  2  2  2  2  1  1  1  1  1  1  1  1] 58.8\n",
      "[20 14  2  2  2  2  2  2  2  2  2  1  1  1  1  1  1  1  1] 58.8\n",
      "[20 14  2  2  2  2  2  2  2  2  2  1  1  1  1  1  1  1  1] 58.8\n",
      "18\n",
      "./data/all_jsonl_data_sample/7dd8f2f49230dee32da3142ac3984412.v.gz-0.jsonl\n",
      "[3 2 2 2 2 2 1 1 1 1 1 1 1] 19.6\n",
      "[3 2 2 2 2 2 1 1 1 1 1 1 1] 19.6\n",
      "3\n",
      "./data/all_jsonl_data_sample/7d43a09a7438300752244fb0d9fb05e8.v.gz-0.jsonl\n",
      "[3 3 3 2 2 2 2 2 2 1] 21.56\n",
      "[3 3 3 2 2 2 2 2 2 1] 21.56\n",
      "3\n",
      "./data/all_jsonl_data_sample/7de4d2d50bb4424b0fe35bae7a83be7b.lef.gz-0.jsonl\n",
      "[4 2 2 2 2 1 1 1 1 1 1 1 1] 19.6\n",
      "[4 2 2 2 2 1 1 1 1 1 1 1 1] 19.6\n",
      "3\n",
      "./data/all_jsonl_data_sample/7d5db7ecf8c09e4d6872e9998d3ffc4c.v.gz-0.jsonl\n",
      "[2 1] 2.94\n",
      "1\n",
      "./data/all_jsonl_data_sample/7d89dd77dcd8779cd013cdad4527558c.h.gz-0.jsonl\n",
      "[3 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] 30.38\n",
      "[3 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] 30.38\n",
      "3\n",
      "./data/all_jsonl_data_sample/7d5a331f93e5a37c5e367230ca1c1a14.cdl.gz-0.jsonl\n",
      "[16 14 14 14 11 11  8  7  7  5  5  4  3  3  2  2  2  2  2  2  1  1  1] 134.26\n",
      "[16 14 14 14 11 11  8  7  7  5  5  4  3  3  2  2  2  2  2  2  1  1  1] 134.26\n",
      "[16 14 14 14 11 11  8  7  7  5  5  4  3  3  2  2  2  2  2  2  1  1  1] 134.26\n",
      "[16 14 14 14 11 11  8  7  7  5  5  4  3  3  2  2  2  2  2  2  1  1  1] 134.26\n",
      "15\n",
      "./data/all_jsonl_data_sample/7deb467f7b1a328162b5b8ae171ca139.scala.gz-0.jsonl\n",
      "[30  7  6  6  6  4  3  3  2  2  2  1  1  1  1] 73.5\n",
      "[30  7  6  6  6  4  3  3  2  2  2  1  1  1  1] 73.5\n",
      "[30  7  6  6  6  4  3  3  2  2  2  1  1  1  1] 73.5\n",
      "[30  7  6  6  6  4  3  3  2  2  2  1  1  1  1] 73.5\n",
      "14\n",
      "./data/all_jsonl_data_sample/7d146ea4e04027f987bc4c8d1bf2326e.cdl.gz-0.jsonl\n",
      "[4 4 3 2 2 2 2 2 1 1 1] 23.52\n",
      "[4 4 3 2 2 2 2 2 1 1 1] 23.52\n",
      "3\n"
     ]
    }
   ],
   "source": [
    "# selecting the top-K tokens in a way that their cumulative frequency accounts for approximately 98%\n",
    "get_high_freq_tokens(token_usage_path, high_freq_tokens_path, float(freq_threshold))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b040dc93-6667-4c8c-91a7-dd715151bbc6",
   "metadata": {},
   "source": [
    "## Step 5: Initialize the embeddings of the new tokens by utilizing the extended general-purpose tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2e64e1d-fb92-4cd5-b19a-ff411fb231d7",
   "metadata": {},
   "source": [
    "Here we use the `extend_tokenizer_high_freq_tokens` helper function to first add the high-frequency tokens identified in Step 4 to the original tokenizer vocabulary.\n",
    "\n",
    "Both the embedding table and the output layer weights of the original model depend on the vocabulary size. Since adding the high-frequency domain-specific tokens changes the vocabulary size, both need to be updated.\n",
    "\n",
    "`extend_sentencepiece` initializes the embeddings of the new tokens by utilizing the general-purpose tokenizer. When a new token (a word or subword unit) is encountered, it is first broken down (tokenized) using the pretrained general-purpose tokenizer. \n",
    "\n",
    "The new token doesn’t have a predefined embedding (a numerical representation). The embedding of the new token is determined by averaging the embeddings of the tokens generated by the general-purpose tokenizer. For example, if the new token is split into three sub-tokens, the embeddings of these three sub-tokens are averaged to form the embedding of the new token.\n",
    "\n",
    "Similarly, the output layer weights corresponding to the new token are initialized to the average of the output layer weights of the sub-tokens produced by the general-purpose tokenizer. For example, if the new token is split into three sub-tokens, the weights corresponding to these three sub-tokens are averaged to form the weights corresponding to the new token.\n",
    "\n",
    "Once done, in Step 6 we will merge the new embeddings with the original embedding table (in llama2) to get the final Domain Adapted Tokenizer."
   ]
  },
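  {
   "cell_type": "markdown",
   "id": "5d2e9f70-embedding-average-sketch",
   "metadata": {},
   "source": [
    "The averaging rule described above can be illustrated on a toy embedding table. This is a sketch under our own assumptions rather than the code inside the helper functions: `init_new_token_embedding`, the 4 x 3 table, and the sub-token IDs `[1, 2, 3]` are hypothetical, and NumPy stands in for the actual model weights.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def init_new_token_embedding(embedding_table, sub_token_ids):\n",
    "    # Average the existing rows of the sub-tokens that the original\n",
    "    # tokenizer produces for the new token.\n",
    "    return embedding_table[sub_token_ids].mean(axis=0)\n",
    "\n",
    "\n",
    "# Toy table: 4 original tokens, embedding dimension 3.\n",
    "embedding_table = np.array([\n",
    "    [0.0, 0.0, 0.0],\n",
    "    [1.0, 2.0, 3.0],\n",
    "    [3.0, 2.0, 1.0],\n",
    "    [2.0, 2.0, 2.0],\n",
    "])\n",
    "\n",
    "# Suppose the original tokenizer splits the new token into sub-tokens 1, 2 and 3.\n",
    "new_row = init_new_token_embedding(embedding_table, [1, 2, 3])\n",
    "\n",
    "# Appending the row gives the new token the next free ID (4 in this toy case).\n",
    "extended_table = np.vstack([embedding_table, new_row])\n",
    "print(new_row, extended_table.shape)\n",
    "```\n",
    "\n",
    "The same averaging would be applied to the corresponding rows of the output layer, so the new token starts from a representation the model already understands instead of a random vector."
   ]
  },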
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2644f313-66d8-47d0-9fc0-f1b2edc72d79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "ori_tokenizer_path = f\"./models/tokenizer/{model_type}/original_tokenizer/tokenizer.model\" # original sentencepiece tokenizer model\n",
    "new_vocab_path = f\"./models/tokenizer/{model_type}/new_tokenizer/freq_vocab.json\" # path to record added new tokens\n",
    "old_ebd_path = f\"./models/weight/{model_type}/ori_{model_type}-hf_weight/\" # original embeddings\n",
    "new_ebd_path = f\"./models/weight/{model_type}/new_{model_type}-hf_weight/\" # path to store augmented embeddings\n",
    "domain_adapter_tokenizer_path = f\"./models/tokenizer/{model_type}/new_tokenizer/tokenizer_freq.model\" # augmented sentencepiece model\n",
    "split = 8 # num of partitions to split the augmented embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3946599e-43b6-4391-b6ab-0068f9f93113",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "with open(high_freq_tokens_path, \"r\") as f:\n",
    "    new_tokens = json.load(f)\n",
    "print(\"new_tokens: \", new_tokens)\n",
    "extend_tokenizer_high_freq_tokens(data_root, ori_tokenizer_path, new_tokens, new_vocab_path, domain_adapter_tokenizer_path, old_ebd_path, new_ebd_path, split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fda5c5ad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(new_ebd_path) #New weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27f8da73",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(domain_adapter_tokenizer_path) # domain-adapted tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "227c5b66-bb02-4fc6-96c3-c4284d2f6e99",
   "metadata": {},
   "source": [
    "# Step 6: Merge the new embeddings with the original embedding table (in llama2) to get the final Domain Adapted Tokenizer and Embeddings."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05999ff2",
   "metadata": {},
   "source": [
    "The helper function `merge_embed` takes the original embeddings downloaded from Hugging Face and the augmented embeddings generated in Step 5, merges them, and saves the result at `save_path`.\n",
    "\n",
    "The figure below illustrates the embedding table modification. Each row corresponds to a unique token and each column represents a dimension of the embedding vector, so the vocabulary size determines the number of rows and the embedding dimensionality determines the number of columns. The embedding layer in the LLM uses this table to convert input tokens into numerical vectors. <br>\n",
    "\n",
    "![pipeline](imgs/embedding_table.png)"
   ]
  },
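  {
   "cell_type": "markdown",
   "id": "e06b7c34-merge-table-sketch",
   "metadata": {},
   "source": [
    "Conceptually, the merge appends the rows for the new tokens below the original rows while leaving the original rows untouched. The sketch below shows this on toy arrays. It is not the implementation of `merge_embed`, which works on the saved checkpoint partitions; `merge_embedding_tables` and the array shapes are hypothetical and ignore partitioning entirely.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def merge_embedding_tables(original_table, new_rows):\n",
    "    # Both tables must share the embedding dimension (number of columns).\n",
    "    if original_table.shape[1] != new_rows.shape[1]:\n",
    "        raise ValueError('embedding dimensions differ')\n",
    "    # New token rows are appended after the original vocabulary rows.\n",
    "    return np.concatenate([original_table, new_rows], axis=0)\n",
    "\n",
    "\n",
    "original = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 tokens, dim 3\n",
    "new_rows = np.ones((2, 3), dtype=np.float32)              # 2 new tokens\n",
    "merged = merge_embedding_tables(original, new_rows)\n",
    "print(merged.shape)\n",
    "```\n",
    "\n",
    "Because the original rows keep their positions, every token ID from the original tokenizer still maps to the same embedding after the merge."
   ]
  },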
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21658fb2-a54a-41cf-94fc-81735767cdab",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(f\"./models/weight/new_merged_{model_type}-hf\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c616e7f",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "old_ebd_path = f\"./models/weight/{model_type}-hf\" # original embeddings downloaded from hf\n",
    "new_ebd_path = f\"./models/weight/{model_type}/new_{model_type}-hf_weight\" # augmented embeddings\n",
    "save_path = f\"./models/weight/new_merged_{model_type}-hf\" # Path to adapted llama2 weights\n",
    "merge_embed(old_ebd_path, new_ebd_path, save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9060868f",
   "metadata": {},
   "source": [
    "### New weights and tokenizer are stored at:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca159e0c-28b2-4a64-bdf4-314e4191c2a0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(new_ebd_path) #New weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d657312-814a-4ad8-a46f-c9c7bc4ee978",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(domain_adapter_tokenizer_path) # domain-adapted tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0782e9c",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# check layers and dimensions (optional)\n",
    "state_dict = torch.load(f'{save_path}/consolidated.01.pth')\n",
    "for index, (key, value) in enumerate(state_dict.items()):\n",
    "    print(f\"Index: {index}, layer: {key}, Layer size: {value.size()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15a172e2",
   "metadata": {},
   "source": [
    "# Next Step\n",
    "\n",
    "The final Domain Adapted Tokenizer obtained using this notebook can be used in a continual pre-training pipeline for domain adaptive pre-training (DAPT)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
